import math
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init, Parameter
import torch.nn.functional as F
from torch.distributions.normal import Normal
from torch.distributions.multivariate_normal import MultivariateNormal


def stack_X(x, p, q):
    """Build the design matrix that turns a flattened affine map into a matmul.

    The result has shape ``(p, p * q + p)``: row ``i`` holds ``x`` in columns
    ``[i * q, (i + 1) * q)`` and a ``1`` in column ``p * q + i``.  Hence, for
    a ``(p, q)`` matrix ``W`` and a length-``p`` vector ``b``::

        stack_X(x, p, q) @ torch.cat((W.reshape(-1), b)) == W @ x + b

    Args:
        x: Tensor with ``q`` elements (any shape; it is flattened).
        p: Number of output rows.
        q: Length of ``x``; sizes the zero blocks.

    Returns:
        Tensor of shape ``(p, p * q + p)`` on the same device as ``x``.
    """
    # reshape (not view) so non-contiguous inputs are accepted as well.
    col = x.reshape(-1, 1)
    # Allocate helper tensors on x's device so torch.cat never mixes devices;
    # dtype is left at the default so type promotion matches plain
    # torch.zeros / torch.eye.  One shared zero block is enough because
    # torch.cat copies its inputs.
    zero = torch.zeros(q, 1, device=col.device)
    columns = [
        torch.cat([col if i == j else zero for j in range(p)]) for i in range(p)
    ]
    block_diag = torch.cat(columns, dim=1).T

    return torch.cat((block_diag, torch.eye(p, device=col.device)), dim=1)


class OnPolicyCoherentLinear(nn.Linear):
    """Linear layer with temporally coherent Gaussian parameter noise.

    The flattened parameters ``w = [weight.view(-1), bias]`` are sampled from
    ``N(mean_w, cov_w)``, where ``cov_w`` is diagonal with learnable log
    standard deviations ``log_std_w``.  After the first draw, successive
    samples follow ``w_t = beta * w_tilde + (1 - beta) * w_{t-1}`` with
    ``w_tilde ~ N(mean_w, (2 / beta - 1) * cov_w)``.  Since
    ``beta ** 2 * (2 / beta - 1) + (1 - beta) ** 2 == 1`` the marginal of
    ``w_t`` stays ``N(mean_w, cov_w)``.

    The class also tracks a Gaussian posterior over ``w`` given the previous
    input and the vector stored through :meth:`get_a`, and exposes the
    resulting marginal distribution via :meth:`get_marginal_pi`.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        beta: Mixing coefficient of the noise process; ``0`` keeps the first
            sample forever.
        std_w_init: Initial standard deviation of every entry of ``w``.
        std_a: Standard deviation used for the isotropic covariance ``cov_a``.
    """

    def __init__(
        self, in_features, out_features, beta=0.01, std_w_init=0.017, std_a=0.1,
    ):
        super(OnPolicyCoherentLinear, self).__init__(
            in_features, out_features, bias=True
        )
        self.in_dim = in_features
        self.out_dim = out_features
        self.beta = beta
        self.cov_a = torch.eye(self.out_dim) * std_a ** 2
        self.prec_a = torch.inverse(self.cov_a)

        # init log_std_w
        # Built in torch with the weight's dtype: multiplying the np.float64
        # returned by np.log with a float32 NumPy array yields float64 under
        # NumPy >= 2, which would make cov_w float64 while mean_w is float32.
        n_params = self.in_dim * self.out_dim + self.out_dim
        log_std_w = float(np.log(std_w_init)) * torch.ones(
            n_params, dtype=self.weight.dtype
        )
        self.log_std_w = torch.nn.Parameter(log_std_w)

        self.reset()

    def forward(self, x):
        """Apply the layer with freshly sampled (noisy) or mean parameters.

        With noise enabled the input is remembered in ``self.x`` for
        :meth:`get_marginal_pi`; the noise-free branch does not update it.
        """
        if self.if_with_noise:
            W, b = self.sample_w()
            self.x = x
            return F.linear(x, W, b)
        else:
            return F.linear(x, self.weight, self.bias)

    def sample_w(self):
        """Draw the next flattened parameter vector and split it into W, b.

        Returns:
            Tuple ``(W, b)`` with shapes ``(out_dim, in_dim)`` and
            ``(out_dim,)``.
        """
        if self.prev_w is None:
            # First draw after reset(): sample from N(mean_w, cov_w).
            p_w_0 = MultivariateNormal(self.mean_w, self.cov_w)
            w = p_w_0.sample()
        else:
            if self.beta > 0:
                w = self.beta * self.p_w_tilde.sample() + (1 - self.beta) * self.prev_w
            else:
                # beta == 0: keep the previous sample (as a new tensor).
                w = self.prev_w * 1.0

        self.prev_w = w

        # The last out_dim entries are the bias, the rest is W in row-major order.
        W = torch.reshape(w[: -self.out_dim], (self.out_dim, self.in_dim))
        b = w[-self.out_dim :]

        return W, b

    def infer_post_p_w(self):
        """Return mean and covariance of the current Gaussian belief over ``w``.

        Before any data has been recorded this is ``(mean_w, cov_w)``.
        Afterwards the previous belief is conditioned on ``(prev_x, prev_a)``
        through a linear-Gaussian update and then pushed through one step of
        the noise process.
        """
        if self.prev_post_mean_w is None:
            post_mean_w = self.mean_w * 1.0
            post_cov_w = self.cov_w * 1.0
        else:
            X_prev = stack_X(self.prev_x, self.out_dim, self.in_dim)

            # Posterior covariance after observing prev_a at prev_x.
            Sigma = torch.inverse(
                self.prev_post_prec_w + X_prev.T @ self.prec_a @ X_prev
            )

            # Posterior mean after observing prev_a at prev_x.
            u = Sigma @ (
                X_prev.T @ self.prec_a @ self.prev_a
                + self.prev_post_prec_w @ self.prev_post_mean_w
            )

            # One step of w_t = beta * w_tilde + (1 - beta) * w_{t-1};
            # note beta ** 2 * (2 / beta - 1) == 2 * beta - beta ** 2.
            post_mean_w = (1 - self.beta) * u + self.beta * self.mean_w
            post_cov_w = (2 * self.beta - self.beta ** 2) * self.cov_w + (
                1 - self.beta
            ) ** 2 * Sigma

        return post_mean_w, post_cov_w

    def get_marginal_pi(self):
        """Return the Gaussian over ``a`` at ``self.x`` with ``w`` integrated out.

        Side effect: the current ``x``, ``a`` and belief over ``w`` become the
        "previous" quantities consumed by the next call, so the value stored
        through :meth:`get_a` is only used on the following call.
        """
        X = stack_X(self.x, self.out_dim, self.in_dim)

        post_mean_w, post_cov_w = self.infer_post_p_w()

        marginal_mean_a = X @ post_mean_w
        marginal_cov_a = self.cov_a + X @ post_cov_w @ X.T
        marginal_pi = MultivariateNormal(marginal_mean_a, marginal_cov_a)

        # update data for inference
        self.prev_x = self.x
        self.prev_a = self.a
        self.prev_post_mean_w = post_mean_w
        self.prev_post_prec_w = torch.inverse(post_cov_w)

        return marginal_pi

    def get_a(self, a):
        """Store ``a`` for the next posterior update (a setter despite the name)."""
        self.a = a

    def get_mean_std_w(self):
        """Return the mean of ``exp(log_std_w)`` as a Python float."""
        return torch.exp(self.log_std_w.detach()).mean().item()

    def get_entropy_p_w(self):
        """Return the entropy of ``N(mean_w, cov_w)`` as a Python float."""
        with torch.no_grad():
            return MultivariateNormal(self.mean_w, self.cov_w).entropy().item()

    def reset(self):
        """Enable noise, clear stored inference data and rebuild ``p(w)``.

        ``mean_w`` and ``cov_w`` are rebuilt from the current parameters and
        are not detached.
        """
        # with or without parameter noise
        self.if_with_noise = True

        # empty data for inference
        self.prev_w = None
        self.x = None
        self.prev_x = None
        self.a = None
        self.prev_a = None
        self.prev_post_mean_w = None
        self.prev_post_prec_w = None

        self.mean_w = torch.cat((self.weight.view(-1), self.bias), 0)
        self.cov_w = torch.diag(torch.exp(self.log_std_w) ** 2)
        self.prec_w = torch.inverse(self.cov_w)

        if self.beta > 0:
            # Innovation distribution of the noise process (unused for beta == 0).
            self.p_w_tilde = MultivariateNormal(
                self.mean_w, (2 / self.beta - 1) * self.cov_w
            )


class OffPolicyCoherentLinear(nn.Linear):
    """Linear layer with temporally coherent, non-learned parameter noise.

    The flattened parameters ``w = [weight.view(-1), bias]`` are perturbed
    with independent Gaussian noise of standard deviation ``std_w``.  After
    the first draw, successive samples follow
    ``w_t = beta * w_tilde + (1 - beta) * w_{t-1}`` with
    ``w_tilde ~ N(mean_w, (2 / beta - 1) * std_w ** 2)``.  The sampled
    parameters are detached from ``weight``/``bias``.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        beta: Mixing coefficient of the noise process; ``0`` keeps the first
            sample forever.
        std_w_init: Initial noise standard deviation.
        alpha: Multiplicative step used by :meth:`adapt_std_w`.
        mse_threshold: Distance threshold used by :meth:`adapt_std_w`.
    """

    def __init__(
        self,
        in_features,
        out_features,
        beta=0.01,
        std_w_init=0.017,
        alpha=1.01,
        mse_threshold=0.1,
    ):
        super(OffPolicyCoherentLinear, self).__init__(
            in_features, out_features, bias=True
        )
        self.in_dim = in_features
        self.out_dim = out_features
        self.beta = beta
        self.std_w = std_w_init
        self.alpha = alpha
        self.mse_threshold = mse_threshold
        self.reset()

    def forward(self, x):
        """Apply the layer with sampled parameters, or the mean ones if noise is off."""
        if self.if_with_noise:
            W, b = self.sample_w()
            return F.linear(x, W, b)
        else:
            return F.linear(x, self.weight, self.bias)

    def adapt_std_w(self, distance):
        """Shrink ``std_w`` if ``distance`` exceeds the threshold, else grow it.

        The new value only affects sampling after the next :meth:`reset`,
        which rebuilds the noise distributions.
        """
        if distance > self.mse_threshold:
            self.std_w /= self.alpha
        else:
            self.std_w *= self.alpha

    def sample_w(self):
        """Draw the next flattened parameter vector and split it into W, b.

        Returns:
            Tuple ``(W, b)`` with shapes ``(out_dim, in_dim)`` and
            ``(out_dim,)``.
        """
        if self.prev_w is None:
            # First draw after reset(): sample around the current parameters.
            p_w_0 = Normal(self.mean_w, self.diag_std_w)
            w = p_w_0.sample()
        else:
            if self.beta > 0:
                w = self.beta * self.p_w_tilde.sample() + (1 - self.beta) * self.prev_w
            else:
                # beta == 0: keep the previous sample (as a new tensor).
                w = self.prev_w * 1.0

        self.prev_w = w

        # The last out_dim entries are the bias, the rest is W in row-major order.
        W = torch.reshape(w[: -self.out_dim], (self.out_dim, self.in_dim))
        b = w[-self.out_dim :]

        return W, b

    def reset(self):
        """Enable noise, forget the last sample and rebuild the noise distributions."""
        self.if_with_noise = True
        self.prev_w = None
        self.mean_w = torch.cat((self.weight.view(-1), self.bias), 0).detach()
        # Allocate the std vector with the parameters' dtype and device so the
        # distributions stay consistent after .to(device) / .double().
        self.diag_std_w = self.std_w * torch.ones(
            self.in_dim * self.out_dim + self.out_dim,
            dtype=self.mean_w.dtype,
            device=self.mean_w.device,
        )
        if self.beta > 0:
            # Plain Python float scale: avoids relying on how a NumPy scalar
            # times a Tensor is dispatched.
            self.p_w_tilde = Normal(
                self.mean_w, math.sqrt(2 / self.beta - 1) * self.diag_std_w
            )


class NoisyLinear(nn.Linear):
    """Linear layer with independent Gaussian noise on weights and biases.

    The forward pass uses ``weight + std_W * noise_W`` and
    ``bias + std_b * noise_b``, where ``std_W``/``std_b`` are learnable
    parameters and ``noise_W``/``noise_b`` are standard-normal buffers that
    change only when :meth:`sample_noise` (or :meth:`reset`) is called.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        std_w_init: Initial value of every entry of ``std_W`` and ``std_b``.
        bias: Accepted but ignored; a bias is always created.
    """

    def __init__(self, in_features, out_features, std_w_init=0.017, bias=True):
        super(NoisyLinear, self).__init__(in_features, out_features, bias=True)
        self.std_w_init = std_w_init
        self.std_W = Parameter(torch.Tensor(out_features, in_features))
        self.std_b = Parameter(torch.Tensor(out_features))
        self.register_buffer("noise_W", torch.zeros(out_features, in_features))
        self.register_buffer("noise_b", torch.zeros(out_features))
        self.reset_parameters()
        self.reset()

    def reset_parameters(self):
        """Initialise weight/bias uniformly and the noise scales to ``std_w_init``.

        ``nn.Linear.__init__`` calls this before ``std_W`` exists; the guard
        turns that early call into a no-op and ``__init__`` calls it again
        once all parameters have been created.
        """
        if hasattr(self, "std_W"):
            bound = math.sqrt(3 / self.in_features)
            init.uniform_(self.weight, -bound, bound)
            init.uniform_(self.bias, -bound, bound)
            init.constant_(self.std_W, self.std_w_init)
            init.constant_(self.std_b, self.std_w_init)

    def forward(self, input):
        """Apply the layer with the current noise, or without it if noise is off."""
        if self.if_with_noise:
            return F.linear(
                input,
                self.weight + self.std_W * self.noise_W,
                self.bias + self.std_b * self.noise_b,
            )
        else:
            return F.linear(input, self.weight, self.bias)

    def sample_noise(self):
        """Replace both noise buffers with fresh standard-normal samples."""
        # Sample with the parameters' dtype and device so the buffers keep
        # following the module after .to(device) / .double().
        like = self.weight
        self.noise_W = torch.randn(
            self.out_features, self.in_features, dtype=like.dtype, device=like.device
        )
        self.noise_b = torch.randn(
            self.out_features, dtype=like.dtype, device=like.device
        )

    def reset(self):
        """Enable noise and draw a new noise sample."""
        self.if_with_noise = True
        self.sample_noise()


class OurNoisyLinear(nn.Linear):
    """Noisy linear layer whose noise scales are learned in log space.

    The forward pass uses ``weight + exp(log_std_W) * noise_W`` and
    ``bias + exp(log_std_b) * noise_b``.  ``noise_W``/``noise_b`` are
    standard-normal buffers that change only when :meth:`sample_noise` (or
    :meth:`reset`) is called.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        std_w_init: Initial noise standard deviation; its log initialises
            ``log_std_W`` and ``log_std_b``.
        bias: Accepted but ignored; a bias is always created.
    """

    def __init__(self, in_features, out_features, std_w_init=0.017, bias=True):
        super(OurNoisyLinear, self).__init__(in_features, out_features, bias=True)
        self.std_w_init = std_w_init
        self.log_std_W = Parameter(torch.Tensor(out_features, in_features))
        self.log_std_b = Parameter(torch.Tensor(out_features))
        self.register_buffer("noise_W", torch.zeros(out_features, in_features))
        self.register_buffer("noise_b", torch.zeros(out_features))
        self.reset_parameters()
        self.reset()

    def reset_parameters(self):
        """Initialise weight/bias uniformly and the log scales to ``log(std_w_init)``.

        ``nn.Linear.__init__`` calls this before ``log_std_W`` exists; the
        guard turns that early call into a no-op and ``__init__`` calls it
        again once all parameters have been created.
        """
        if hasattr(self, "log_std_W"):
            bound = math.sqrt(3 / self.in_features)
            init.uniform_(self.weight, -bound, bound)
            init.uniform_(self.bias, -bound, bound)
            init.constant_(self.log_std_W, np.log(self.std_w_init))
            init.constant_(self.log_std_b, np.log(self.std_w_init))

    def forward(self, input):
        """Apply the layer with the current noise, or without it if noise is off."""
        if self.if_with_noise:
            return F.linear(
                input,
                self.weight + torch.exp(self.log_std_W) * self.noise_W,
                self.bias + torch.exp(self.log_std_b) * self.noise_b,
            )
        else:
            return F.linear(input, self.weight, self.bias)

    def get_mean_std_w(self):
        """Return the mean noise standard deviation over all entries as a float."""
        return (
            torch.exp(torch.cat((self.log_std_W.view(-1), self.log_std_b), 0))
            .detach()
            .mean()
            .item()
        )

    def sample_noise(self):
        """Replace both noise buffers with fresh standard-normal samples."""
        # Sample with the parameters' dtype and device so the buffers keep
        # following the module after .to(device) / .double().
        like = self.weight
        self.noise_W = torch.randn(
            self.out_features, self.in_features, dtype=like.dtype, device=like.device
        )
        self.noise_b = torch.randn(
            self.out_features, dtype=like.dtype, device=like.device
        )

    def reset(self):
        """Enable noise and draw a new noise sample."""
        self.if_with_noise = True
        self.sample_noise()


class PSNELinear(nn.Linear):
    """Linear layer with adaptive parameter-space noise followed by LayerNorm.

    The forward pass uses ``weight + std_w * noise_W`` and
    ``bias + std_w * noise_b`` and normalises the result with
    ``nn.LayerNorm(out_features)``.  ``std_w`` is a plain float that is
    adapted multiplicatively by :meth:`adapt_std_w`; the noise tensors change
    only when :meth:`sample_noise` (or :meth:`reset`) is called.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        std_w_init: Initial noise standard deviation.
        alpha: Multiplicative adaptation step.
        kl_threshold: Distance threshold used when ``if_pi_stochastic`` is true.
        mse_threshold: Distance threshold used otherwise.
    """

    def __init__(
        self,
        in_features,
        out_features,
        std_w_init=0.017,
        alpha=1.01,
        kl_threshold=0.01,
        mse_threshold=0.1,
    ):
        super(PSNELinear, self).__init__(in_features, out_features, bias=True)
        self.std_w = std_w_init
        self.alpha = alpha
        self.kl_threshold = kl_threshold
        self.mse_threshold = mse_threshold
        self.norm_layer = nn.LayerNorm(out_features)
        self.reset()

    def forward(self, input):
        """Apply the (optionally noisy) affine map followed by layer norm."""
        if self.if_with_noise:
            x = F.linear(
                input,
                self.weight + self.std_w * self.noise_W,
                self.bias + self.std_w * self.noise_b,
            )
        else:
            x = F.linear(input, self.weight, self.bias)

        return self.norm_layer(x)

    def adapt_std_w(self, distance, if_pi_stochastic=True):
        """Shrink ``std_w`` if ``distance`` exceeds its threshold, else grow it.

        Args:
            distance: Measured distance to compare against the threshold.
            if_pi_stochastic: Selects ``kl_threshold`` when true and
                ``mse_threshold`` otherwise.
        """
        threshold = self.kl_threshold if if_pi_stochastic else self.mse_threshold
        if distance > threshold:
            self.std_w /= self.alpha
        else:
            self.std_w *= self.alpha

    def sample_noise(self):
        """Draw fresh standard-normal noise for the weight and the bias."""
        # The noise tensors are plain attributes (not buffers), so they must be
        # created with the parameters' dtype and device to be usable in
        # forward() after .to(device) / .double().
        like = self.weight
        self.noise_W = torch.randn(
            self.out_features, self.in_features, dtype=like.dtype, device=like.device
        )
        self.noise_b = torch.randn(
            self.out_features, dtype=like.dtype, device=like.device
        )

    def reset(self):
        """Enable noise and draw a new noise sample."""
        self.if_with_noise = True
        self.sample_noise()
